feat(gemma4): add Gemma4 26B-A4B MoE and 31B dense support #1855
leofan-lab wants to merge 1 commit into THUDM:main
Conversation
Plugin (model, provider, mbridge), HF<->Megatron converter, retool integration, and 10 unit tests.

Highlights:

- Heterogeneous attention. Sliding layers go through flash-attn (head_dim=256); global layers go through a PyTorch SDPA path because flash-attn 2.x doesn't support head_dim=512.
- Unified CP>1 path. SDPACoreAttention._forward_cp_subseq_mask all-gathers K/V, un-zigzags to global order, and builds per-sub-sequence causal (+ sliding-window) masks from slime's zig-zag global indices. Covers both global and sliding layers.
- Dual RoPE emits a single concatenated tensor (not a tuple) so it doesn't collide with Megatron's (self_attn, cross_attn) rope plumbing.
- Gemma4MoELayer subclasses MoELayer to reuse dispatcher + grouped-GEMM + EP, then swaps in Gemma4Router (RMSNorm -> scale -> proj -> softmax -> topk -> renormalize -> per-expert scale; order is load-bearing).
- Per-layer scalars loaded from the HF checkpoint as non-trainable buffers; rank-0 reads + broadcast to avoid O(world_size) FS hits.
- Converter emits stacked 3D expert tensors for sglang and handles PP offset via get_transformer_layer_offset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@leofan-lab thanks for your great work. Could you please share your training script? When I start the SGLang server for the Gemma4 26B-A4B model, it raises a TP error.
Sure. The error you came across in SGLang is most likely because you need to upgrade to a version that includes Gemma4 support.
Hi, I'm currently trying to upgrade sglang to include Gemma4 support (as in upstream PR #21952). However, I ran into an error while applying the sglang patches. Could you provide some advice on how to resolve it? Any guidance would be greatly appreciated.
@leofan-lab hi, thanks for your great work so far. I'm still looking forward to slime support for Gemma4. Do you have any plans to merge this PR into main?
Summary
Adds Gemma4 (26B-A4B MoE + 31B dense) model support to slime for RL training. Covers model architecture, HF↔Megatron weight conversion, retool integration, and 10 unit tests.
Test validation: parity tests across TP/PP/DP/CP/EP/sliding window all pass (see table below), and RL training runs on Gemma4 26B-A4B MoE and 31B dense for 200+ rollout steps using the retool recipe (see End-to-end validation below).
What's in this PR
Model plugin (`slime_plugins/models/gemma4.py`, `gemma4_provider.py`)

- Heterogeneous attention. Sliding layers use `head_dim=256` + flash-attn with a `(sw-1, 0)` left-window mask. Global layers use `head_dim=512`; flash-attn 2.7.4 doesn't support head dims above 256, so they go through a PyTorch SDPA path.
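A rough sketch of that per-layer dispatch, assuming `flash_attn_func` from flash-attn 2.x and bare `[batch, seq, heads, head_dim]` tensors (the actual module in `gemma4.py` sits behind Megatron's attention interfaces):

```python
import torch.nn.functional as F
from flash_attn import flash_attn_func  # flash-attn 2.x


def attention_forward(q, k, v, *, is_sliding: bool, window: int):
    """Route sliding layers to flash-attn and global layers to SDPA.

    Sliding layers use head_dim=256, which flash-attn 2.x supports;
    global layers use head_dim=512, which it doesn't, so they fall
    back to PyTorch SDPA.
    """
    if is_sliding:
        # (window - 1, 0): each token attends to itself and the previous
        # window - 1 tokens, nothing to the right.
        return flash_attn_func(q, k, v, causal=True,
                               window_size=(window - 1, 0))
    # SDPA expects [batch, heads, seq, head_dim].
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        is_causal=True,
    )
    return out.transpose(1, 2)
```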
- Context parallelism. `SDPACoreAttention._forward_cp_subseq_mask` is a unified CP>1 path for both global and sliding layers: it all-gathers K/V, un-zigzags them to global order, and builds per-sub-sequence causal (+ sliding-window) masks from slime's zig-zag global indices.
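A minimal sketch of the index/mask math, assuming slime's zig-zag layout splits the sequence into `2*cp_size` chunks and that K/V have already been all-gathered back to global order; the real `_forward_cp_subseq_mask` additionally restricts attention to within each packed sub-sequence:

```python
import torch


def zigzag_global_indices(seq_len: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Global token positions held by this CP rank under zig-zag sharding.

    The sequence is cut into 2*cp_size chunks; rank r holds chunks
    r and (2*cp_size - 1 - r), which balances causal-attention work.
    """
    chunk = seq_len // (2 * cp_size)
    front = torch.arange(cp_rank * chunk, (cp_rank + 1) * chunk)
    back_start = (2 * cp_size - 1 - cp_rank) * chunk
    back = torch.arange(back_start, back_start + chunk)
    return torch.cat([front, back])


def build_cp_mask(q_pos: torch.Tensor, seq_len: int,
                  window: int | None = None) -> torch.Tensor:
    """Boolean [local_q, global_kv] mask, True = may attend.

    q_pos are this rank's global positions from zigzag_global_indices();
    the K/V axis is assumed un-zigzagged to global order 0..seq_len-1.
    """
    kv_pos = torch.arange(seq_len)
    mask = kv_pos[None, :] <= q_pos[:, None]           # causal
    if window is not None:                             # sliding window
        mask &= kv_pos[None, :] > q_pos[:, None] - window
    return mask
```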
- Dual RoPE. `DualRotaryEmbedding` wraps the `(local, global)` ropes and emits a single concatenated tensor per call; each layer slices its half based on `is_sliding`. It is a concat (not a tuple) so Megatron's existing 2-tuple `(self_attn, cross_attn)` rope plumbing doesn't misread it. Global layers use partial rotary via zeroed `inv_freq` tail entries (requires `apply_rope_fusion=False`).
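In sketch form (the body is abridged; the real class plugs into Megatron's rotary embedding API, and the `[s, 1, 1, d]` shape follows Megatron's `RotaryEmbedding` output convention):

```python
import torch


class DualRotaryEmbeddingSketch(torch.nn.Module):
    """Emit one tensor holding both rope variants, concatenated on the
    last dim: a 2-tuple would be misread by Megatron as
    (self_attn, cross_attn) ropes."""

    def __init__(self, local_rope, global_rope):
        super().__init__()
        self.local_rope = local_rope    # full rotary, local base (sliding layers)
        self.global_rope = global_rope  # partial rotary via zeroed inv_freq tail

    def forward(self, max_seq_len: int) -> torch.Tensor:
        local = self.local_rope(max_seq_len)     # [s, 1, 1, d]
        glob = self.global_rope(max_seq_len)     # [s, 1, 1, d]
        return torch.cat([local, glob], dim=-1)  # [s, 1, 1, 2d]


def slice_rope(freqs: torch.Tensor, is_sliding: bool) -> torch.Tensor:
    """Per-layer slice of the concatenated tensor."""
    d = freqs.shape[-1] // 2
    return freqs[..., :d] if is_sliding else freqs[..., d:]
```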
- MoE via Megatron's dispatcher. `Gemma4MoELayer` subclasses `MoELayer` to reuse the alltoall dispatcher, grouped-GEMM experts, and EP sharding, but swaps in `Gemma4Router` (no-scale RMSNorm → learnable scale → proj → softmax → topk → renormalize → per-expert scale). The renormalize-then-scale order is load-bearing and guarded by `test_router_matches_hf_reference_equation`.
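The routing equation reads roughly like this (tensor names are hypothetical; the pinned reference is `test_router_matches_hf_reference_equation`):

```python
import torch
import torch.nn.functional as F


def gemma4_router_sketch(h, router_weight, input_scale, expert_scale, top_k):
    """Sketch of the routing chain described above.

    h:             [tokens, hidden]
    router_weight: [num_experts, hidden] router projection
    input_scale:   [hidden] learnable scale applied after a no-scale RMSNorm
    expert_scale:  [num_experts] per-expert output scale
    """
    # no-scale RMSNorm, then the learnable scale
    x = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)
    x = x * input_scale
    logits = F.linear(x, router_weight)               # proj
    probs = logits.softmax(dim=-1)                    # softmax over all experts
    top_p, top_idx = probs.topk(top_k, dim=-1)        # topk
    top_p = top_p / top_p.sum(dim=-1, keepdim=True)   # renormalize first...
    top_p = top_p * expert_scale[top_idx]             # ...then per-expert scale
    return top_p, top_idx
```

Because the per-expert scale is applied after renormalization, the final routing weights need not sum to 1; swapping the two steps changes the output, which is why the order is test-guarded.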
- Per-layer scalars are loaded from the HF safetensors checkpoint as non-learnable buffers and applied after the FFN residual add. Rank-0 reads + `broadcast_object_list` avoid O(world_size) filesystem hits (a sketch of this pattern follows this list). Scalars are mandatory by default; `GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1` downgrades the failure to a warning.
- `attention_k_eq_v` on global layers: `linear_qkv` emits `[q, k]` only when `v_proj_weight == k_proj_weight`, and `_split_qkv_global_k_eq_v` derives `V = v_norm(raw_k)`, `K = k_norm(raw_k)` without mutating `self.k_layernorm`.
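The rank-0-read + broadcast pattern, sketched below; the checkpoint layout and key names are hypothetical, and `torch.distributed` is assumed to be initialized:

```python
import torch.distributed as dist
from safetensors.torch import load_file


def load_layer_scalars_sketch(ckpt_path: str, num_layers: int):
    """Only rank 0 touches the filesystem; every other rank receives the
    (small) per-layer scalars via broadcast_object_list, avoiding
    O(world_size) reads of the HF checkpoint."""
    payload = [None]
    if dist.get_rank() == 0:
        sd = load_file(ckpt_path)  # assumed single-shard safetensors
        payload[0] = [
            float(sd[f"model.layers.{i}.scalar"])  # key name hypothetical
            for i in range(num_layers)
        ]
    # in-place on non-source ranks
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```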
HF↔Megatron conversion (`slime/backends/megatron_utils/megatron_to_hf/gemma4.py`, `slime_plugins/mbridge/gemma4.py`)

- `convert_gemma4_to_hf` emits stacked 3D expert tensors `(E, 2I, H)` / `(E, H, I)` for sglang, drops the `.weight` suffix on stacked keys, and handles the PP layer offset via `get_transformer_layer_offset`.
- `Gemma4Bridge._build_config` explicitly sets `activation_func = gelu_pytorch_tanh` + `bias_activation_fusion = False`: Gemma uses GeGLU, not SwiGLU.
- The QKV layout roundtrip is verified bit-exact by `test_gemma4_qkv_roundtrip.py`.
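A sketch of the expert-stacking layout change (E experts, intermediate size I, hidden size H; the output key names are placeholders for the real HF/sglang contract):

```python
import torch


def stack_experts_sketch(gate_up_per_expert, down_per_expert, layer_idx: int):
    """Per-expert 2D weights -> the stacked 3D layout sglang loads.

    gate_up_per_expert: E tensors of shape [2I, H] (gate and up fused)
    down_per_expert:    E tensors of shape [H, I]
    """
    return {
        # stacked keys deliberately drop the trailing ".weight";
        # exact key names here are illustrative only
        f"model.layers.{layer_idx}.mlp.experts.gate_up_proj":
            torch.stack(gate_up_per_expert, dim=0),  # (E, 2I, H)
        f"model.layers.{layer_idx}.mlp.experts.down_proj":
            torch.stack(down_per_expert, dim=0),     # (E, H, I)
    }
```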
Retool integration (`examples/retool/generate_with_retool_gemma4.py`)

- Uses `tokenizer.apply_chat_template` with Gemma4's native `<|turn>role\n...<turn|>` format instead of the hardcoded Qwen ChatML in the stock retool generate.
- Keeps the `<tool_call>{json}</tool_call>` parsing contract so the shared `reward_func`/`postprocess_predictions` regex still works. Gemma4 instruct follows the system-prompt instruction even though its native `<|tool_call>` format is different.
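An illustrative stand-in for that parsing contract; the actual regex lives in the shared retool `reward_func`/`postprocess_predictions`:

```python
import json
import re

# The model is instructed (via system prompt) to emit
# <tool_call>{json}</tool_call> regardless of its native tool format,
# so one regex serves all models.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)


def extract_tool_calls(text: str) -> list[dict]:
    calls = []
    for m in TOOL_CALL_RE.finditer(text):
        try:
            calls.append(json.loads(m.group(1)))
        except json.JSONDecodeError:
            continue  # malformed call: left to the reward function to score
    return calls
```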
Configs (`scripts/models/gemma4-26B-A4B.sh`, `gemma4-31B.sh`)

- MODEL_ARGS templates; `--swiglu` is intentionally omitted (the activation is set by `get_gemma4_spec`).
Tests (`tests/gemma4/`)

10 test files:

- `test_gemma4_attention.py`: K=V global split, sliding delegation, V = v_norm()
- `test_gemma4_router.py`: Gemma4Router equation, renorm-then-scale order, MoELayer.route adapter
- `test_gemma4_dual_rope.py`: concat/split semantics, real Megatron RotaryEmbedding integration (CUDA-gated)
- `test_gemma4_provider.py`: `_install_hooks`, `_load_layer_scalars`, PP offset translation
- `test_gemma4_qkv_roundtrip.py`: HF↔Mcore QKV bit-exact roundtrip
- `test_gemma4_bridge.py`: forward-only Megatron→HF conversion
- `test_gemma4_layer_integration.py`: real layer build + forward (CUDA-gated)
- `test_gemma4_cp_attention.py`: SDPACoreAttention CP production paths (4 tests exercising the forward dispatch + zig-zag global indices)
- `test_gemma4_hf_key_contract.py`: asserts our Megatron→HF converter emits every key HF's `Gemma4ForConditionalGeneration` state_dict expects; pins the sglang/HF loader contract against future drift
- `test_gemma4_layer_scalar_broadcast.py`: 2-rank gloo test exercising the real rank-0-read + `broadcast_object_list` path for layer scalars (single-process tests can't catch a regression that only affects rank > 0)
pytest tests/gemma4/.Correctness: parity test results
All Gemma4 parallel dims verified via standalone parity harnesses. Summary:
[Parity summary table: results per parallel dim, including the `force_cp_subseq_mask` path]

End-to-end validation
Trained three concurrent configurations via the retool recipe on dapo-math-17k. All runs: 48 H200 GPUs (16 actor + 32 rollout), global_batch_size=256, n_samples_per_prompt=8, GRPO, lr=5e-6, bf16, 0 pod restarts.
Reference
Hugging Face Transformers Gemma4 source:
SGLang Gemma4 model loader: